{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2c06e6d6",
   "metadata": {},
   "source": [
    "# Efficient Fine Tuning of Quantized LLMs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f03853e8",
   "metadata": {},
   "source": [
    "### Installing required packages\n",
    "#### Setting up virtual env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fafd4343",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conda create -n qlora python=3.10\n",
    "# !conda activate qlora\n",
    "# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
    "\n",
    "# pip install scipy\n",
    "# pip install -q -U bitsandbytes\n",
    "# pip install -q -U git+https://github.com/huggingface/transformers.git\n",
    "# pip install -q -U git+https://github.com/huggingface/peft.git\n",
    "# pip install -q -U git+https://github.com/huggingface/accelerate.git\n",
    "# pip install -q datasets\n",
    "# pip install einops  # needed for falcon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "82db245c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import bitsandbytes\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9f11be70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/data/finetuning/qlora'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a147426b",
   "metadata": {},
   "source": [
    "First let's load the model we are going to use - falcon-7b! Note that the model itself is around 14GB in half precision"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6efe865",
   "metadata": {},
   "source": [
    "## Loading LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "30c7690f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7fa95054752b418db8f1340ab143bc68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#model_id = \"EleutherAI/gpt-neox-20b\"\n",
    "# model_id=\"google/flan-ul2\" -- doesn't work\n",
    "model_id=\"tiiuae/falcon-7b\"\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "# how does device_map interact with CUDA_VISIBLE_DEVICES\n",
    "# maybe device_map=\"auto\"?\n",
    "# model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
    "#model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\")\n",
    "# for falcon\n",
    "# model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map=\"auto\", trust_remote_code=True)\n",
    "\n",
    "# Load the model from local\n",
    "model = AutoModelForCausalLM.from_pretrained('/data/finetuning/lit-parrot/checkpoints/tiiuae/falcon-7b/', quantization_config=bnb_config, device_map=\"auto\", local_files_only=True, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8986ce20",
   "metadata": {},
   "source": [
    "## Pre-processing Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f21420f9",
   "metadata": {},
   "source": [
    "Then we have to apply some preprocessing to the model to prepare it for training. For that use the prepare_model_for_kbit_training method from PEFT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63eb17cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import prepare_model_for_kbit_training\n",
    "model.gradient_checkpointing_enable()\n",
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5a1be4eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 9437184 || all params: 3618182016 || trainable%: 0.260826679207064\n"
     ]
    }
   ],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4d13235",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regularization strength (lora_alpha)\n",
    "# Dropout probability (lora_dropout)\n",
    "# target modules to be compressed (target_modules).\n",
    "from peft import LoraConfig, get_peft_model\n",
    "config = LoraConfig(\n",
    "    r=32,\n",
    "    lora_alpha=32,\n",
    "    target_modules=[\"query_key_value\"],\n",
    "    lora_dropout=0.2,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\"\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2019432c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "781faf94",
   "metadata": {},
   "source": [
    "## Loading Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08aebc4c",
   "metadata": {},
   "source": [
    "Let's load a dolly instruct dataset, to fine tune our model on famous quotes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aae75b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "# data = load_dataset(\"Abirate/english_quotes\")\n",
    "data = load_dataset('json', data_files='dolly-instruct-dataset.json')\n",
    "data = data.map(lambda samples: tokenizer(samples[\"instruction\"]), batched=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "139bf5f1",
   "metadata": {},
   "source": [
    "## Fine-tuning model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f343e8de",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset json (/root/.cache/huggingface/datasets/json/default-061f7d4818b7ea6d/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7a3a3b31e404b44a24610f869377fb3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-061f7d4818b7ea6d/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96/cache-9546399c3240be86.arrow\n",
      "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='30' max='30' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [30/30 03:05, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>2.829800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>2.829800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>2.801000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>2.709800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>2.581100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>2.421500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>2.250600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>2.073200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>1.867300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>1.620800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>1.336400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>1.037300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.759400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.538500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.369000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.259700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.202200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.158500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.127600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.110800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.095000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.083300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.073600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.066400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.056500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.049800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.044600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.040100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.037100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.035200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=30, training_loss=0.9821900140494109, metrics={'train_runtime': 193.0904, 'train_samples_per_second': 93.221, 'train_steps_per_second': 0.155, 'total_flos': 3534702046848000.0, 'train_loss': 0.9821900140494109, 'epoch': 10.0})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import transformers\n",
    "# needed for gpt-neo-x tokenizer\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=data[\"train\"],\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=100,\n",
    "        gradient_accumulation_steps=6,\n",
    "        warmup_steps=2,\n",
    "        max_steps=30, #10,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=True,\n",
    "        logging_steps=1,\n",
    "        output_dir=\"outputs2\",\n",
    "        optim=\"paged_adamw_8bit\",\n",
    "        report_to=\"none\"  # turns off wandb -> def. turn off for CCC.\n",
    "    ),\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
    ")\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74755fd7",
   "metadata": {},
   "source": [
    "## Saving Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8061bd05",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(\"outputs2/\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a320ee4",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8e3f9487",
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModelForCausalLM, get_peft_config\n",
    "\n",
    "compute_dtype = getattr(torch, \"float16\")\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    "    bnb_4bit_compute_dtype=compute_dtype,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    ")\n",
    "\n",
    "# model = AutoModelForCausalLM.from_pretrained(\n",
    "#     \"tiiuae/falcon-7b-instruct\", quantization_config=bnb_config, device_map=\"auto\", trust_remote_code=True\n",
    "# )\n",
    "\n",
    "# You can comment and un comment this line to either use base model \n",
    "# or the peft model during the inference.\n",
    "model = PeftModelForCausalLM.from_pretrained(model, 'outputs',local_files_only=True)\n",
    "\n",
    "tok = AutoTokenizer.from_pretrained('ybelkada/falcon-7b-sharded-bf16')\n",
    "tok.pad_token = tok.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "74ef0daa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ": How can I create an account?\n",
      ": How can I change my password?\n",
      ": How\n",
      "CPU times: user 3.07 s, sys: 2.6 ms, total: 3.07 s\n",
      "Wall time: 3.07 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "prompt = f\"\"\"\n",
    ": How can I create an account?\n",
    ":\n",
    "\"\"\".strip()\n",
    "\n",
    "encoding = tok(prompt, return_tensors=\"pt\")\n",
    "with torch.inference_mode():\n",
    "    outputs = model.generate(\n",
    "        input_ids=encoding.input_ids,\n",
    "        attention_mask=encoding.attention_mask,\n",
    "#         generation_config=generation_config,\n",
    "    )\n",
    "print(tok.decode(outputs[0], skip_special_tokens=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "78d8cb82",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:11 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer the question based only on the context below. Context: it a debate that modern football fans have been having over recent years Ive got to get your take on it who better Messi or Ronaldo Ronaldo come two or three years as the one of the best in Europe I think in the moment no doubt no doubt that Ronald was the best teacher watch out lava sure Bobby what going on is really hot it that rule very well so they need to get the ball they have enough nobody gonna pick now Ronaldo really good bright sharp start to the game I just said pick test from but Im sure in rise to the occasion chickens are trying to take him a different levels different inscription here we see him come on the ball terrific balance look look at that lovely little trick that he has and his profes 21 international and you saw him take apart England under-21 side one with a couple of those well Kresna has gone and seen to that really speed of the ball either freaking manipulating it make decision. Question: Who is the best football player? Answer: Ronaldo is the best football player.\n"
     ]
    }
   ],
   "source": [
    "prompt = f\"\"\"\n",
    "Answer the question based only on the context below. \\\n",
    "Context: it a debate that modern football fans have been having over recent years Ive got to get your take on it who better Messi or Ronaldo Ronaldo come two or three years as the one of the best in Europe I think in the moment no doubt no doubt that Ronald was the best teacher watch out lava sure Bobby what going on is really hot it that rule very well so they need to get the ball they have enough nobody gonna pick now Ronaldo really good bright sharp start to the game I just said pick test from but Im sure in rise to the occasion chickens are trying to take him a different levels different inscription here we see him come on the ball terrific balance look look at that lovely little trick that he has and his profes 21 international and you saw him take apart England under-21 side one with a couple of those well Kresna has gone and seen to that really speed of the ball either freaking manipulating it make decision. \\\n",
    "Question: Who is the best football player?\n",
    "\"\"\".strip()\n",
    "\n",
    "# inputs = tok(prompt, return_tensors=\"pt\")\n",
    "# inputs = inputs.to(0)\n",
    "# output = model.generate(inputs[\"input_ids\"])\n",
    "# tokenizer.decode(output[0].tolist())\n",
    "\n",
    "encoding = tok(prompt, return_tensors=\"pt\")\n",
    "with torch.inference_mode():\n",
    "    outputs = model.generate(\n",
    "        input_ids=encoding.input_ids,\n",
    "        attention_mask=encoding.attention_mask,\n",
    "        max_new_tokens=1000,\n",
    "#         generation_config=generation_config,\n",
    "    )\n",
    "print(tok.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23b33f28",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
